#include <smile/smile.h>

#include <ctype.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <exception>
#include <string>
#include <vector>

#include "record.h"
#include "record_file.h"

void calculateLikelihood(std::string network, std::string recordFilename);

  // Program entry point.
  //
  // Expects exactly two arguments: the network file and the data (record)
  // file, which are handed to calculateLikelihood().
  //
  // Returns EXIT_SUCCESS after the likelihood has been computed, or
  // EXIT_FAILURE (after printing a usage line to stderr) when the argument
  // count is wrong, so that calling scripts can detect misuse.
  int main(int argc, char **argv)
  {
    if (argc != 3) {
        // argv[0] is not guaranteed to be present when argc is 0.
        const char *programName = (argc > 0 && argv[0]) ? argv[0] : "<program>";
        fprintf(stderr, "usage: %s <net> <data>\n", programName);
        return EXIT_FAILURE;
    }

    calculateLikelihood(argv[1], argv[2]);
    return EXIT_SUCCESS;
  }

    // the description of valid SMILE identifiers is here:
    // http://genie.sis.pitt.edu/wiki/Reference_Manual:_DSL_header
  // Predicate: true when `c` is neither alphanumeric nor an underscore,
  // i.e. when it is a character that has to be replaced to form an identifier.
  //
  // The character is converted to unsigned char before it is classified:
  // the <ctype.h> functions are only defined for values representable as
  // unsigned char (or EOF), and plain char may be negative for bytes >= 0x80.
  bool isNotAlphaNumUnderscore(char c) {
    if (c == '_') {
        return false;
    }
    return !isalnum(static_cast<unsigned char>(c));
  }

  // Normalises an arbitrary label into identifier form:
  //   * if the first character is not a letter (this includes an empty
  //     input), an 'x' is prepended;
  //   * every character that is neither alphanumeric nor '_' becomes '_';
  //   * all letters are lower-cased.
  //
  // Returns a pointer to a NUL-terminated string held in a function-local
  // static buffer. The pointer stays valid until the next call to
  // makeValidId(), so callers must consume or copy the result before calling
  // again. Because of the shared buffer this function is not safe to call
  // from several threads at once.
  //
  // (Returning id.c_str() directly would hand out a pointer into the by-value
  // parameter, which is destroyed once the call completes.)
  const char *makeValidId(std::string id) {
    static std::string buffer;

    // Identifiers have to start with a letter; prefix anything else with 'x'.
    if (id.empty() || !isalpha(static_cast<unsigned char>(id[0]))) {
        id = "x" + id;
    }

    for (std::string::size_type pos = 0; pos < id.size(); ++pos) {
        // <ctype.h> functions require a value representable as unsigned char.
        const unsigned char ch = static_cast<unsigned char>(id[pos]);
        if (isalnum(ch) || ch == '_') {
            id[pos] = static_cast<char>(tolower(ch));
        } else {
            id[pos] = '_';
        }
    }

    buffer.swap(id);
    return buffer.c_str();
  }

  // Looks up the node of `theNet` whose identifier is the normalised form
  // (see makeValidId) of `variable` and returns what FindNode() reports.
  // The identifier pointer is consumed inside the same expression that
  // produces it.
  int getHandle(DSL_network &theNet, std::string variable) {
    const int nodeHandle = theNet.FindNode(makeValidId(variable));
    return nodeHandle;
  }

  
  // Translates an outcome label into its position in a node's outcome list.
  //
  //   theNet  network that owns the node
  //   handle  handle of the node whose outcomes are searched
  //   value   raw outcome label; normalised with makeValidId() before lookup
  //
  // Returns the result of FindPosition() for the normalised label, or -1 when
  // `handle` is negative, does not resolve to a node, or the node has no
  // outcome-name list.
  // NOTE(review): assumes GetNode() yields a null pointer for an unknown
  // handle and that FindPosition() reports "not found" as a negative value --
  // confirm against the SMILE reference.
  int getEvidenceIndex(DSL_network &theNet, int handle, std::string value) {
    // Copy the identifier in the very expression that produces it, so the
    // rest of this function never depends on the lifetime of the buffer
    // returned by makeValidId().
    const std::string val(makeValidId(value));

    if (handle < 0 || !theNet.GetNode(handle)) {
        return -1;
    }

    DSL_idArray *theNames = theNet.GetNode(handle)->Definition()->GetOutcomesNames();
    if (!theNames) {
        return -1;
    }

    return theNames->FindPosition(val.c_str());
  }

  // Enters `value` as evidence on the node identified by `handle`.
  //
  // The label is mapped to an outcome index with getEvidenceIndex(). When the
  // handle or the resulting index is negative, or SetEvidence() does not
  // return DSL_OKAY, a warning is written to stderr and no evidence is left
  // set by this call; the function itself never aborts the program.
  // NOTE(review): assumes SMILE signals lookup failures with negative
  // handles/indices -- confirm against the SMILE reference.
  void setEvidence(DSL_network &theNet, int handle, std::string value) {
    if (handle < 0) {
        fprintf(stderr, "warning: cannot set evidence '%s': invalid node handle %d\n",
                value.c_str(), handle);
        return;
    }

    const int evidenceIndex = getEvidenceIndex(theNet, handle, value);
    if (evidenceIndex < 0) {
        fprintf(stderr, "warning: '%s' is not a known outcome of node %d; evidence not set\n",
                value.c_str(), handle);
        return;
    }

    const int status = theNet.GetNode(handle)->Value()->SetEvidence(evidenceIndex);
    if (status != DSL_OKAY) {
        fprintf(stderr, "warning: SetEvidence failed (code %d) for value '%s' on node %d\n",
                status, value.c_str(), handle);
    }
  }

  void calculateLikelihood(std::string network, std::string recordFilename) {
    DSL_network theNet;
    //printf("about to read network\n");
    theNet.ReadFile(network.c_str());
    //printf("read network\n");

    // read in the records
    datastructures::RecordFile recordFile(recordFilename, ',', true);
    recordFile.read();
    //printf("read records\n");

    // use jointree
    theNet.SetDefaultBNAlgorithm(DSL_ALG_BN_LAURITZEN);

    // first, get the handle to each node
    std::vector<int> handles;
    int variableCount = recordFile.getHeader().size();
    for (int i = 0; i < variableCount; i++) {
        int handle = getHandle(theNet, recordFile.getHeader().get(i));
        handles.push_back(handle);
    }

    //printf("have handles\n");

    double logLikelihood = 0.0;

    // now, for each record in the test set, calculate its log likelihood
    for (int i = 0; i < recordFile.size(); i++) {
        datastructures::Record record = recordFile.getRecord(i);

        for (int i = 0; i < variableCount; i++) {
            //printf("setting evidence, variable: %d, value: %s\n", i, record.get(i).c_str());
            setEvidence(theNet, handles[i], record.get(i));
        }

        double pe = 0;
        //printf("calculating evidence\n");
        try {
                bool res = theNet.CalcProbEvidence(pe);
                //printf("res: %d, the probability of the evidence is: %f\n", res, pe);

                double logPe = log(pe);
                logLikelihood += logPe;
                //printf("the log probability of the evidence is: %f\n", logPe);
        } catch (...) {
            printf("there was an exception\n");
        }

        theNet.ClearAllEvidence();
    }

    printf("The log likelihood of the data given the network is: %f\n", logLikelihood);


  };

